from segformer.cd_transformer import Segformer
from main import build_model, parse_args

import torch

if __name__ == "__main__":
    # Smoke test: build the model from command-line arguments, run a single
    # forward pass on random input, and print the resulting output shape.
    args = parse_args()
    model = build_model(args)

    # Random input of shape (2, 4, 512, 512); the same tensor is fed as both
    # arguments of the model's two-input forward call.
    # NOTE(review): 4 input channels is assumed to match the model built by
    # build_model(args) — confirm against the model config.
    dummy = torch.randn(2, 4, 512, 512)

    # Only the output shape is inspected, so disable autograd to avoid
    # recording a computation graph (and holding its activations) for this pass.
    with torch.no_grad():
        out = model(dummy, dummy)
    print(out.shape)
